# -*- coding: utf-8 -*-
"""
nasbench301.patches
-------------------
Centralized implementation:
    1) ConfigSpace decoder patches (NormalFloat/NormalInteger; log-scale boundary safety)
    2) JSON cleaning: replace "default" -> "default_value", and use new CS API to read
    3) Patch NB301's ConfigLoader to use the above "cleaned reading"
    4) Align nb_api.fixed_hyperparameters with ConfigSpace (legalize/fallback to default values)
    5) Provide convenient functions to load models and ConfigSpace (local models directory)
Patches are applied automatically when this module is imported. Importing never triggers a model download; models are only downloaded by ensure_models_downloaded(), which load_models() calls.
"""

from __future__ import annotations
import os, re, importlib
from io import StringIO
from typing import Any, Dict, Tuple
from pathlib import Path

# ----- ConfigSpace patch: decoders -----
import ConfigSpace.read_and_write.dictionary as cs_dict
import ConfigSpace.read_and_write.json as cs_json
import ConfigSpace.configuration_space as cs_cs
from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import (
        NormalFloatHyperparameter, NormalIntegerHyperparameter,
        Constant, CategoricalHyperparameter, IntegerHyperparameter, FloatHyperparameter
)
from ConfigSpace.read_and_write.dictionary import _backwards_compat


def _choose_bounds_float(itm: dict):
        name = itm.get('name', '').lower()
        mu   = float(itm.get('mu', 1.0))
        log  = bool(itm.get('log', False))
        lower = itm.get('lower', None)
        upper = itm.get('upper', None)

        if log:
                if 'learning_rate' in name or name.endswith(':lr') or ':lr' in name:
                        if lower is None or lower <= 0: lower = 1e-5
                        if upper is None or upper <= lower: upper = 1.0
                elif 'weight_decay' in name or 'wd' in name or 'decay' in name:
                        if lower is None or lower <= 0: lower = 1e-6
                        if upper is None or upper <= lower: upper = 1e-1
                else:
                        if lower is None or lower <= 0: lower = 1e-12
                        cand = max(10.0 * abs(mu), 1.0)
                        if upper is None or upper <= lower: upper = cand
        else:
                if lower is None: lower = float('-inf')
                if upper is None: upper = float('inf')
        return lower, upper


def _choose_bounds_int(itm: dict):
        log  = bool(itm.get('log', False))
        lower = itm.get('lower', None)
        upper = itm.get('upper', None)
        if log:
                if lower is None or lower < 1: lower = 1
                if upper is None or upper <= lower:
                        mu = int(round(itm.get('mu', 10)))
                        upper = max(mu, lower + 10)
        else:
                if lower is None: lower = -2**31
                if upper is None or upper <= lower: upper = 2**31 - 1
        return int(lower), int(upper)


def _patched_decode_normal_float(item, cs=None, decoder=None):
        """Decode a serialized normal-float entry with repaired bounds.

        Args:
                item: Serialized hyperparameter mapping; it is copied, not mutated.
                cs: Unused; accepted so the function fits the decoder call signature.
                decoder: Unused; accepted for the same reason.

        Returns:
                A ``NormalFloatHyperparameter`` built from the copied mapping after
                ``_backwards_compat`` and with ``lower``/``upper`` replaced by the
                values from ``_choose_bounds_float``.
        """
        params = _backwards_compat(dict(item))
        params['lower'], params['upper'] = _choose_bounds_float(params)
        return NormalFloatHyperparameter(**params)


def _patched_decode_normal_int(item, cs=None, decoder=None):
        """Decode a serialized normal-integer entry with repaired bounds.

        Args:
                item: Serialized hyperparameter mapping; it is copied, not mutated.
                cs: Unused; accepted so the function fits the decoder call signature.
                decoder: Unused; accepted for the same reason.

        Returns:
                A ``NormalIntegerHyperparameter`` built from the copied mapping after
                ``_backwards_compat`` and with ``lower``/``upper`` replaced by the
                values from ``_choose_bounds_int``.
        """
        params = _backwards_compat(dict(item))
        params['lower'], params['upper'] = _choose_bounds_int(params)
        return NormalIntegerHyperparameter(**params)


def patch_configspace_decoders(verbose: bool = True):
        """Install the bound-repairing decoders into ConfigSpace's dictionary reader.

        The float decoder attribute is always overwritten. The integer decoder is
        only replaced when ``ConfigSpace.read_and_write.dictionary`` already
        defines ``_decode_normal_int``. If the module exposes a
        ``HYPERPARAMETER_DECODERS`` table, entries that already exist under the
        known normal-float / normal-int aliases are redirected as well. Finally
        the JSON reader and ``configuration_space`` modules are reloaded
        (presumably so they pick up the patched decoders -- confirm against the
        installed ConfigSpace version).

        Args:
                verbose: When true, print a confirmation line.
        """
        cs_dict._decode_normal_float = _patched_decode_normal_float

        int_decoder_present = hasattr(cs_dict, "_decode_normal_int")
        if int_decoder_present:
                cs_dict._decode_normal_int = _patched_decode_normal_int

        if hasattr(cs_dict, "HYPERPARAMETER_DECODERS"):
                registry = cs_dict.HYPERPARAMETER_DECODERS
                for alias in ('normal_float', 'normal', 'normalfloat'):
                        if alias in registry:
                                registry[alias] = _patched_decode_normal_float
                if int_decoder_present:
                        for alias in ('normal_int', 'normalinteger', 'normal_integer'):
                                if alias in registry:
                                        registry[alias] = _patched_decode_normal_int

        for module in (cs_json, cs_cs):
                importlib.reload(module)

        if verbose:
                print("✓ Patched ConfigSpace decoders; modules reloaded.")


# ----- JSON cleaning and new API -----
def load_cs_sanitized(json_path: str) -> ConfigurationSpace:
        """Read a ConfigSpace JSON file after renaming legacy ``"default"`` keys.

        Every ``"default":`` key in the raw text is rewritten to
        ``"default_value":`` before the text is handed to
        ``ConfigurationSpace.from_json`` through an in-memory stream.

        Args:
                json_path: Path to the JSON file describing the configuration space.

        Returns:
                The ``ConfigurationSpace`` parsed from the cleaned JSON text.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (which differs e.g. on Windows).
        with open(json_path, "r", encoding="utf-8") as handle:
                raw = handle.read()
        cleaned = re.sub(r'"default"\s*:', '"default_value":', raw)
        return ConfigurationSpace.from_json(StringIO(cleaned))


# ----- NB301 ConfigLoader patch -----
def patch_nb301_config_loader(verbose: bool = True):
        """Replace NB301's ``ConfigLoader.load_config_space`` with the sanitized reader.

        After this call, ``ConfigLoader.load_config_space(path)`` simply returns
        ``load_cs_sanitized(path)``.

        Args:
                verbose: When true, print a confirmation line.
        """
        import nasbench301 as nb
        from nasbench301.surrogate_models import utils as nb_utils

        def _patched_load_config_space(self, path: str):
                # ``self`` is ignored; loading depends only on the file path.
                return load_cs_sanitized(path)

        setattr(nb_utils.ConfigLoader, "load_config_space", _patched_load_config_space)

        if verbose:
                print("✓ Patched NB301 ConfigLoader -> sanitized JSON + new CS API.")


# ----- Fixed hyperparameter legalization -----
def _as_bool(x):
        if isinstance(x, str):
                if x.lower() == "true":  return True
                if x.lower() == "false": return False
        return x


def harmonize_fixed_to_cs(fixed: dict, cs: ConfigurationSpace) -> dict:
        """Return a copy of ``fixed`` whose values are legal for ``cs``.

        Steps, in order:

        1. Copy ``fixed`` and turn ``"true"``/``"false"`` strings into booleans.
        2. For every ``Constant`` in ``cs`` store its value, and for every
           single-choice categorical store that only choice (adding the key if
           it was absent).
        3. For every hyperparameter of ``cs`` that has an entry, keep the value if
           ``legal_value`` accepts it. Otherwise try to parse string values as
           ``int``/``float`` for integer/float hyperparameters, then fall back to
           the hyperparameter's ``default_value`` when that is legal, and drop the
           entry as a last resort.

        Args:
                fixed: Mapping of hyperparameter name to fixed value; not mutated.
                cs: Configuration space whose hyperparameters define legality.

        Returns:
                The harmonized dict.
        """
        result = dict(fixed)
        result = {key: _as_bool(val) for key, val in result.items()}

        # Hyperparameters with only one possible value are pinned to it.
        for hp in cs.values():
                key = hp.name
                if isinstance(hp, Constant):
                        result[key] = hp.value
                elif isinstance(hp, CategoricalHyperparameter) and len(hp.choices) == 1:
                        result[key] = hp.choices[0]

        for hp in cs.values():
                key = hp.name
                if key not in result:
                        continue
                current = result[key]

                # A failing legality check is treated the same as an illegal value.
                try:
                        acceptable = hp.legal_value(current)
                except Exception:
                        acceptable = False
                if acceptable:
                        continue

                is_numeric_hp = isinstance(hp, (IntegerHyperparameter, FloatHyperparameter))
                if is_numeric_hp and isinstance(current, str):
                        # Best effort: a numeric value may have been stored as text.
                        try:
                                parse = int if isinstance(hp, IntegerHyperparameter) else float
                                parsed = parse(current)
                                if hp.legal_value(parsed):
                                        result[key] = parsed
                                        continue
                        except Exception:
                                pass

                fallback = getattr(hp, "default_value", None)
                if fallback is not None and hp.legal_value(fallback):
                        result[key] = fallback
                else:
                        result.pop(key, None)
        return result


def load_configspace_with_patches() -> ConfigurationSpace:
        """Load the ConfigSpace and align ``nasbench301.api.fixed_hyperparameters`` with it.

        A ``configspace.json`` next to this module takes precedence; otherwise the
        one located in the installed ``nasbench301`` package directory is used.
        ``nasbench301.api.fixed_hyperparameters`` is replaced by its harmonized
        version, and a status line is printed reporting any entries that are
        still illegal afterwards.

        Returns:
                The loaded ``ConfigurationSpace``.
        """
        import nasbench301 as nb
        import nasbench301.api as nb_api

        here = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(here, 'configspace.json')
        if os.path.exists(candidate):
                cs_path = candidate
        else:
                cs_path = os.path.join(os.path.dirname(nb.__file__), 'configspace.json')
        cs = load_cs_sanitized(cs_path)

        harmonized = harmonize_fixed_to_cs(nb_api.fixed_hyperparameters, cs)
        nb_api.fixed_hyperparameters = harmonized

        # Sanity log: list every fixed entry the space still rejects.
        invalid = [
                (hp.name, harmonized[hp.name])
                for hp in cs.values()
                if hp.name in harmonized and not hp.legal_value(harmonized[hp.name])
        ]
        if invalid:
                print("WARN: invalid fixed after harmonize:", invalid)
        else:
                print("✓ nb_api.fixed_hyperparameters harmonized with ConfigSpace.")
        return cs


# ----- Model download and loading (local directory) -----
def ensure_models_downloaded(version: str = "1.0") -> Dict[str, str]:
        """Return the local NB301 model directories, downloading them if missing.

        Models are expected under ``nb_models_<version>/`` next to this module,
        one sub-directory per surrogate named ``<model>_v<version>``.

        Args:
                version: NB301 model version string, e.g. ``"1.0"``.

        Returns:
                Dict mapping ``'xgb'``, ``'gnn_gin'`` and ``'lgb_runtime'`` to their
                model directory paths.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        models_dir = os.path.join(current_dir, f'nb_models_{version}')
        # The per-model directory name carries the version too; hard-coding
        # "_v1.0" here would make every other version look permanently missing.
        model_paths = {
                stem: os.path.join(models_dir, f'{stem}_v{version}')
                for stem in ('xgb', 'gnn_gin', 'lgb_runtime')
        }
        if not all(os.path.exists(p) for p in model_paths.values()):
                # Imported lazily: the package is only needed when a download is required.
                import nasbench301 as nb
                print(f"Downloading NB301 models (v{version}) to:", current_dir)
                nb.download_models(version=version, delete_zip=True, download_dir=current_dir)
        return model_paths


def load_models(version: str = "1.0"):
        """Load the XGB performance surrogate and the LGB runtime surrogate.

        The model directories are resolved (and downloaded when missing) via
        ``ensure_models_downloaded`` before loading.

        Args:
                version: NB301 model version string passed to ``ensure_models_downloaded``.

        Returns:
                ``(performance_model, runtime_model)`` as returned by
                ``nasbench301.load_ensemble``.
        """
        import nasbench301 as nb

        locations = ensure_models_downloaded(version=version)

        print("==> Loading performance surrogate (XGB)...")
        performance = nb.load_ensemble(locations['xgb'])

        print("==> Loading runtime surrogate (LGB)...")
        runtime = nb.load_ensemble(locations['lgb_runtime'])

        return performance, runtime


# ----- One-click patch application -----
def apply_all_patches(verbose: bool = True):
        """Apply the ConfigSpace decoder patch, then the NB301 ConfigLoader patch.

        Args:
                verbose: Forwarded to each patch function to control its confirmation print.

        Returns:
                Always ``True``.
        """
        for patch in (patch_configspace_decoders, patch_nb301_config_loader):
                patch(verbose=verbose)
        return True


# Lightweight patches are automatically applied upon import (no model download).
# Both patch functions run with verbose=True, so importing this module prints their
# confirmation lines and imports nasbench301 as a side effect.
# apply_all_patches() returns True, so this flag records that the import-time run finished.
_AUTO_APPLIED = apply_all_patches(verbose=True)

# Names exported by a star-import of this module; underscore-prefixed helpers stay private.
__all__ = [
        "apply_all_patches",
        "patch_configspace_decoders",
        "patch_nb301_config_loader",
        "load_cs_sanitized",
        "load_configspace_with_patches",
        "harmonize_fixed_to_cs",
        "ensure_models_downloaded",
        "load_models",
]

